import torch

# Small demo of reducing a tensor by product along its last dimension.
# `a` holds the integers 0..17 laid out as a (3, 3, 2) tensor; `arange`
# output is contiguous, so `.view` gives the same layout as `.reshape`.
a = torch.arange(3 * 3 * 2).view(3, 3, 2)
print('a\n', a)

# Multiply the two entries along the last axis; the result `b` has shape (3, 3).
b = torch.prod(a, dim=-1)

# Show the reduced tensor, its shape, and a flattened 1-D view of it.
for _item in (b, b.shape, b.view(-1)):
    print('prod\n', _item)
